QTM 447 Lecture 18: Attention is All You Need

Kevin McAlister

March 20, 2025

\[ \newcommand\hbb{{\hat{\boldsymbol \beta}}} \newcommand\bb{{\boldsymbol \beta}} \newcommand\expn{{\frac{1}{N} \sum \limits_{i = 1}^N}} \newcommand\sumk{\sum \limits_{k = 1}^K} \newcommand\argminb{\underset{\bb}{\text{argmin }}} \newcommand\argmaxb{\underset{\bb}{\text{argmax }}} \newcommand\gtheta{\mathbf g(\boldsymbol \theta)} \newcommand\htheta{\mathbf H(\boldsymbol \theta)} \]

Seq2Seq Problems

Seq2Seq models are RNNs that convert an input sequence to an output sequence

  • We’re going to restrict our attention to the unaligned variant

  • The input sequence and the output sequence are not necessarily of the same length and may not have any direct one-to-one correspondence

Most commonly seen as question and answer or translation

  • My screen is blank. \(\rightarrow\) Please check if the computer is plugged in.

  • Bless your little heart \(\rightarrow\) You are sorely mistaken

Seq2Seq Problems

The “simple” model:

Seq2Seq Problems

In the training case

  • We see “Bless Your Little Heart” and “You are Sorely Mistaken”. Train the model to maximize the probability that this translation occurs.

At test time

  • Given a prompt to translate “Bless Your Little Heart” from Southern to English, return “You Are Sorely Mistaken” token-by-token given a trained model.

In either case, the model (after embedding and decoding) returns a prediction of what word comes next in the decoder

  • Usually a probability vector over words in the vocabulary

  • Sample from this distribution to get the next token!

LM Odds and Ends

A few quick bites worth mentioning (way more info on these things posted to the Canvas site).

  • Not discussed in detail - more appropriate for a NLP class (QTM 340, for example)

Integer Encoding

  • Words are not numbers!

  • Tokenize inputs/outputs into words (or parts)

  • Map each unique token to an integer

  • Make explicit tokens for start of sentence and end of sentence

    • Computers don’t understand proper sentence structure!

    • Denoted here as <sos> and <eos>

    • Just another token. Not needed when the start and end are clear!

Odds and Ends

Example:

<sos> Bless your little heart <eos> <sos> You are sorely mistaken <eos>

\[ [1,2,3,4,5,6,1,7,8,9,10,6] \]

Words are not continuous!

  • Treat like unordered categorical features

Odds and Ends

One-hot encoding:

For the entire vocabulary (all unique tokens), let each possible token be a feature, so each token becomes its own indicator vector.

  • Code as one if token is equal to unique vocab token!

Represent the vectors as \(\mathbf W\), a \(\text{Number of Tokens} \times \text{Size of Vocabulary}\) matrix

  • A big ol’ matrix of zeros and ones

  • Mostly sparse

  • Very high dimensional
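A minimal sketch of the one-hot construction, assuming integer ids running 1 through \(M\) as in the example above:

```python
import numpy as np

# Sketch of one-hot encoding: one row per token, one column per vocab entry.
def one_hot(codes, vocab_size):
    W = np.zeros((len(codes), vocab_size))
    W[np.arange(len(codes)), np.array(codes) - 1] = 1.0  # ids start at 1
    return W

codes = [1, 2, 3, 4, 5, 6, 1, 7, 8, 9, 10, 6]
W = one_hot(codes, vocab_size=10)
print(W.shape)  # (12, 10) -- big, sparse, mostly zeros
```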

Odds and Ends

Word Embeddings

Take the one-hot matrix and project the points onto a lower dimensional subspace

\[ \underset{(N \times P)}{\mathbf X} = \underset{(N \times M)}{\mathbf W} \underset{(M \times P)}{\mathbf D} \]

  • \(\mathbf W\) is the one-hot encoded matrix of words

  • \(\mathbf D\) is a projection matrix

  • \(N\) input tokens, \(M\) tokens in the vocabulary, \(P\) embedding dimensions
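The projection can be sketched directly; here \(\mathbf D\) is random, standing in for a learned or borrowed embedding matrix:

```python
import numpy as np

# Sketch of the projection X = W D. D is random here, standing in for a
# learned Embedding layer or a borrowed matrix (word2vec, GloVe, ...).
rng = np.random.default_rng(0)
N, M, P = 12, 10, 4                        # toy sizes: N tokens, M vocab, P dims
codes = np.array([1, 2, 3, 4, 5, 6, 1, 7, 8, 9, 10, 6]) - 1
W = np.eye(M)[codes]                       # one-hot matrix, N x M
D = rng.normal(size=(M, P))                # projection matrix, M x P
X = W @ D                                  # dense N x P embedding matrix

# Because W is one-hot, row i of X is just row codes[i] of D: a table lookup.
assert np.allclose(X, D[codes])
print(X.shape)  # (12, 4)
```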

Odds and Ends

Learned via Embeddings layers

or

Borrowed:

  • word2vec

  • GloVe

  • BERT (discussed more next class)

Today, we’re going to be pretty agnostic about the embedding method

  • Just assume that the input is \(N\) vectors (one for each token) of length \(P\) arranged as a dense \(N \times P\) matrix

Seq2Seq Problems

The “simple” model:

Seq2Seq Problems

The hidden state updates as we move through the sequence

  • For seq2seq, we don’t actually need to predict the next input token at each encoder step! All we care about is getting our desired answer!

The hidden state of the encoder is used as the input for the decoder

  • Let it ride - just continue the recurrent sequence

Remember the issue with this from last time?

Seq2Seq Problems

The decoder may do better if it is allowed to look at all hidden states in the encoder instead of just looking at the final state

  • All encoder hidden states carry information

  • Every output token is related to the input directly

Instead, allow the decoder to use the encoder with attention

Attention

At \(t = 1\) of the decoder:

  • Start with the input embedding for <sos>

  • Using the previous decoder hidden state, \(\mathbf s_0\), find the dot product between the previous hidden state and all encoder hidden states

  • Convert these dot products to attention weights

  • Make the context vector a convex combination of all encoder hidden states weighted by the attention weights

  • Proceed like a RNN
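The steps above can be sketched in NumPy; the hidden states here are random, purely for illustration:

```python
import numpy as np

# Sketch of the attention step at decoder time t = 1, using the dot-product
# scoring described above.
rng = np.random.default_rng(1)
H = rng.normal(size=(5, 8))    # 5 encoder hidden states, each of dimension 8
s0 = rng.normal(size=8)        # previous decoder hidden state

scores = H @ s0                                  # one dot product per encoder state
weights = np.exp(scores) / np.exp(scores).sum()  # softmax -> attention weights
context = weights @ H                            # convex combination of states

print(context.shape)  # (8,)
```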

Attention

Attention allows the decoder to find relevant parts of the input for determining what should come next in the decoder.

Input:

My computer won’t turn on

Output:

Is…

[it, she, there, …]

After emitting “Is”, the decoder recomputes attention over the same input to choose the next token - the context updates at every output step.

Attention

With attention, we’re allowed to update context in the decoder!

A step forward from “simple” seq2seq models

Problems:

  • Slow - each update must be done iteratively

  • Very deep - unrolling this entire process into a feedforward-style model shows that the number of layers is really high; one for each input and output token

  • Relatively poor memory - the encoder still only looks one step back at a time and can lose older information

Attention

A warning: this is all going to move relatively quickly

  • Some steps to get from seq2seq with attention to transformers

Just keep in mind that autodiff will be able to backprop through everything here

  • TensorFlow/PyTorch will handle it all!

Move towards a sentient machine that understands the complexities of human language/images

  • Light on code today - mostly theory. Not too useful until we put all the pieces together.

  • Applications to come next class.

Attention

Given the attention setup, do we need our encoder to be recurrent?

  • The decoder looks back at every step each time

  • Does it matter if we understand the sequential nature of the input?

Yes:

  • Context of the input is sequential - different words make sense in the context of other input words

  • Need to know that “Little” in “Bless Your Little Heart” is pejorative instead of a descriptor!

No:

  • The decoder doesn’t really care

Recurrence is only needed to find the hidden states of the encoder!

Self-Attention

Allow the encoder to develop context of the input by looking at all other words in the input

In words:

  • For each word in the input, compute query-key-value sets (linear transformations of the input embeddings)

  • For each word \(i\):

    • Compute the dot product similarity between the query for \(i\) and all keys \(j \in T_e\) - \(\mathbf q_i^T \mathbf k_j\)

    • Softmax these similarities to get attention weights for each word - \(\mathbf w_i = [w_{i1},w_{i2},...,w_{iT_e}]\)

    • Compute the output, \(\mathbf o_i\), as the weighted combination of each \(v_j\) and the corresponding attention weight \(w_{ij}\)

The encoder looks forwards and backwards to see which words of the input provide context for each input word!

  • Context!
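The query-key-value recipe above can be sketched as one matrix computation; the weight matrices are random stand-ins for learned parameters:

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    return np.exp(z) / np.exp(z).sum(axis=-1, keepdims=True)

# Sketch of one self-attention head over the whole input at once.
# Wq, Wk, Wv stand in for learned linear transformations.
rng = np.random.default_rng(2)
T, P, d = 6, 8, 4              # T input tokens, P-dim embeddings, d-dim q/k/v
X = rng.normal(size=(T, P))
Wq, Wk, Wv = (rng.normal(size=(P, d)) for _ in range(3))

Q, K, V = X @ Wq, X @ Wk, X @ Wv
A = softmax(Q @ K.T)           # T x T weights: row i attends over every j
O = A @ V                      # o_i = weighted combination of the v_j

print(O.shape)  # (6, 4)
```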

Self-Attention

Each self-attention operation will correspond to one notion of context:

  • Because \(\rightarrow\) [didn’t, cross] (what)

But, there are often layers of context:

  • Because \(\rightarrow\) [wide] (why)

Allow each layer of context to be uncovered using multiple self-attention operators

Self-Attention

Multiheaded Self-Attention:

Combining the \(H\) self-attention heads:

  • Concatenate all of the representations - \(M\)-dimensional outputs from \(H\) heads give \(M \times H\) values for each input token

  • Use a standard flat neural network with ReLU activations to map the \(M \times H\) values to a single \(M\)-vector.

  • The overall context is a weighted combination of the \(H\) self-attention contexts

This looks a lot like a standard flat NN

  • Could we improve our contextual understanding of the input by stacking self attention layers?
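A sketch of the head-combining step, with random weights standing in for learned ones; the single ReLU layer here is one simple version of the "flat NN" described above:

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    return np.exp(z) / np.exp(z).sum(axis=-1, keepdims=True)

def head(X, Wq, Wk, Wv):
    Q, K, V = X @ Wq, X @ Wk, X @ Wv
    return softmax(Q @ K.T) @ V

# Sketch of multiheaded self-attention: H heads each give an M-vector per
# token; concatenate to M*H values, then map back down to M.
rng = np.random.default_rng(3)
T, M, H = 6, 4, 3              # T tokens, M-dim representations, H heads
X = rng.normal(size=(T, M))
heads = [head(X, *(rng.normal(size=(M, M)) for _ in range(3))) for _ in range(H)]
concat = np.concatenate(heads, axis=1)   # T x (M*H)
Wo = rng.normal(size=(M * H, M))
out = np.maximum(concat @ Wo, 0.0)       # ReLU map back to T x M

print(out.shape)  # (6, 4)
```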

Self-Attention

Each layer of self-attention can be thought of like a CNN:

  • The bottom layers pick up on very specific contexts (subject to verb, word-to-word relations)

  • Higher layers widen the picked up contextual features (tone, pronouns, etc.)

  • The top levels combine all previous layers of context to get broad subject matter context (a sentence about an animal, a pejorative statement about one’s naive view of the world)

Very powerful contextual machine with no recurrence

  • No order necessary!

Self-Attention

But order is important!

  • Positional context

  • A context word one word away is different from one 10 words away…

  • The self-attention framework ignores this positional importance

Solution: create both word embeddings and positional embeddings and concatenate them as the input!

Positional Embeddings

Each word is associated with a position

\[ \underset{0}{\text{<sos>}}\underset{1}{\text{ bless}}\underset{2}{\text{ your}}\underset{3}{\text{ little}}\underset{4}{\text{ heart}}\underset{5}{\text{ <eos>}} \]

Ideal: Just pass integer encodings

  • Integers aren’t continuous

Learn from integer encodings?

  • Could work. But, really expensive.

It’s not absolute position that matters, but relative position!

Positional Embeddings

Instead, come up with a hard coding scheme that links word positions to relative distance

Reversed Binary Encoding:

\[ \begin{array}{c|ccc}\text{Token} & b_0 & b_1 & b_2 \\ \hline\text{<sos>} & 0 & 0 & 0 \\\text{bless} & 1 & 0 & 0 \\\text{your} & 0 & 1 & 0 \\\text{little}& 1 & 1 & 0 \\\text{heart} & 0 & 0 & 1 \\\text{<eos>} & 1 & 0 & 1 \\\end{array} \]

Note that the leftmost column flips the fastest, and columns further right flip progressively more slowly - together, the columns encode relative distance!

Gets a little unruly, so use a different basis

Positional Embeddings

More compact representation - sinusoidal embeddings of dimension \(D\):

\[ \tiny \mathbf p_i = \left[\sin \left( \frac{i}{C^{0/D}} \right),\cos \left( \frac{i}{C^{0/D}} \right),\sin \left( \frac{i}{C^{2/D}} \right),\cos \left( \frac{i}{C^{2/D}} \right),\dots,\sin \left( \frac{i}{C^{(D - 2)/D}} \right),\cos \left( \frac{i}{C^{(D - 2)/D}} \right)\right] \]

  • Similar structure to binary encoding

  • Allows for continuous mapping

  • Does encode distance! Given a distance, we know how far the two vectors should be from one another

  • Works quite well.
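The sinusoidal embedding above can be sketched directly; \(C = 10000\) (the constant from the original transformer paper) and an even \(D\) are assumed:

```python
import numpy as np

# Sketch of the sinusoidal positional embedding p_i.
def positional_embedding(i, D, C=10000.0):
    p = np.zeros(D)
    for k in range(0, D, 2):       # pair up sin/cos at frequencies C^(k/D)
        angle = i / C ** (k / D)
        p[k] = np.sin(angle)
        p[k + 1] = np.cos(angle)
    return p

p0 = positional_embedding(0, D=8)      # position 0: alternating (0, 1) pairs
p1 = positional_embedding(1, D=8)
print(np.round(p0, 3))  # [0. 1. 0. 1. 0. 1. 0. 1.]
```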

Positional Embeddings

Self-Attention

Self-attention looks forwards and backwards in the input to develop context as a function of the input sentence

  • Really, a collection of training sentences

Using this non-recurrent self-attention framework for the encoder, we get a different seq2seq model!

Quick note: we can also make regular attention (linking the encoder to the decoder) multiheaded

Training Time vs. Test Time

When training, we see:

[Bless,Your,Little,Heart, <eos>] \(\rightarrow\) [<sos>,You,Are,Sorely,Mistaken,<eos>]

Using the model to predict a translation, we see:

[Bless,Your,Little,Heart, <eos>] \(\rightarrow\) [<sos>]

Why won’t regular self-attention work for us in the decoder?

Masked Self-Attention

A solution: Don’t let the Self-Attention in the decoder block look ahead!

Easiest way to do this:

At time \(t\), force all self-attention weights for tokens \(t + 1\) through \(T_d\) to be exactly zero!

  • Really quite simple to do in the attention setup

  • Still differentiable, so still handled by TensorFlow/PyTorch
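The zeroing trick can be sketched by setting the scores for future tokens to minus infinity before the softmax:

```python
import numpy as np

# Sketch of the masking trick: set scores for future tokens to -inf before
# the softmax, which forces those attention weights to be exactly zero.
def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    return np.exp(z) / np.exp(z).sum(axis=-1, keepdims=True)

rng = np.random.default_rng(4)
T = 5
scores = rng.normal(size=(T, T))                  # raw q_i . k_j scores
mask = np.triu(np.ones((T, T), dtype=bool), k=1)  # True above diagonal = future
scores[mask] = -np.inf
A = softmax(scores)

print(np.round(A, 2))  # lower-triangular rows; each row still sums to 1
```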

Masked Self-Attention

Use masked self-attention to build context within the available tokens of the output sentence!

  • No recurrence necessary at this step

However, we’ll still need recurrence for the encoder/decoder attention

  • Have to allow the output sequence to evolve over time with information from the encoder

Masked Self-Attention

Transformers

This is the transformer

  • A self-attending encoder

  • A mostly self-attending decoder

Replace recurrence with positional encodings and multi-headed self-attention

A few additional bells and whistles to make everything tick

Transformers

“Attention is all you need” (Vaswani et al., 2017)

Transformers

  • Just as good as SoTA recurrent models

  • Way lower training cost (orders of magnitude lower)

Transformers

Better performance links to the vanishing gradient problem

Remember deeper \(=\) harder gradient to compute

For RNNs:

\[ \mathbf h_t = \text{tanh}(\mathbf W_{hh} \mathbf h_{t-1} + \mathbf W_{xh} \mathbf x_t) \]

  • Each input/output token corresponds to a fully connected layer!

  • “Bless your little heart <eos> <sos> You are sorely mistaken <eos>” is an 11-layer network!

  • A 200 token question and 100 token answer will require a 300-layer NN!
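The sequential depth can be sketched with the update equation above (random weights and made-up toy sizes):

```python
import numpy as np

# Sketch of the depth argument: each token applies the recurrent update
# h_t = tanh(Whh h_{t-1} + Wxh x_t) once, so 11 tokens = 11 stacked layers.
rng = np.random.default_rng(5)
Hd, P = 4, 3
Whh = rng.normal(size=(Hd, Hd))
Wxh = rng.normal(size=(Hd, P))
xs = rng.normal(size=(11, P))      # 11 tokens, as in the example above

h = np.zeros(Hd)
depth = 0
for x in xs:                       # must run in order: h_t needs h_{t-1}
    h = np.tanh(Whh @ h + Wxh @ x)
    depth += 1
print(depth)  # 11
```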

Transformers

For transformers, layers correspond to attention blocks:

  • Encoder self-attention block corresponds to 2 layers (self-attention and FCNN)

  • Decoder masked self-attention corresponds to 4ish layers (masked self-attention + FCNN + attention + FCNN)

GPT-1 only had 12 decoder blocks

  • 48 layers regardless of sequence length

Tradeoff is way more parameters

Transformers

Parameters are not a problem, though, due to clever parallel design

Transformers

Assume we have 1 million GPUs that can do dot products and cross products

1,000 tokens

  • Each GPU is responsible for one \(i\) subpart - \(\mathbf q_i\), \(\mathbf k_i\), \(\mathbf v_i\) (one computation for 3,000 GPUs)

  • Then, each GPU is responsible for one dot product - \(\mathbf q_i^T \mathbf k_j\) (one computation for 1 million GPUs)

  • Finally, each GPU normalizes the dot products via softmax for one \(i\) (one computation for 1,000 GPUs)

With no communication overhead (not trivial!!!!), computing self-attention for a 1,000 token sequence takes the same amount of time as computing it for 1 token since each dumb processor only has to do one dumb little task
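The claim that the score computations are independent can be checked directly: the one-shot matrix product equals computing each dot product separately, so the entries could be farmed out one per processor in any order:

```python
import numpy as np

# Check that the score matrix decomposes into independent dot products.
rng = np.random.default_rng(6)
T, d = 4, 3
Q = rng.normal(size=(T, d))
K = rng.normal(size=(T, d))

all_at_once = Q @ K.T                                    # one batched operation
one_per_gpu = np.array([[Q[i] @ K[j] for j in range(T)]  # one entry at a time
                        for i in range(T)])

assert np.allclose(all_at_once, one_per_gpu)
print(all_at_once.shape)  # (4, 4)
```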

Transformers

In contrast, a recurrent seq2seq model requires sequential processing at every step

  • We can’t compute \(\mathbf h_2\) until we know the value of \(\mathbf h_1\)

  • There is no real clever parallelization

This is why GPUs have gotten so expensive in the AI boom (supply and demand, baby!)

  • Each RTX 5090 GPU has 21,760 cores that can act independently

  • Roughly 46 5090s needed to get 1 million GPU cores

  • At 2k per GPU (lol), about $92,000 to build a state of the art transformer machine that takes literal milliseconds to perform self-attention

  • If electricity is free…

Transformers

Parallelization is key here for the next steps with transformers

  • There is still one part of the transformer that is not parallelizable in training

  • Remember, recurrence = not parallelizable

Which step?

Transformers

Smart ML folks looking for three additional zeroes on their series B funding check asked an important question:

Do we even need an encoder and a decoder?

  • This was the billion dollar question

The translation of encoder to decoder was the computational bottleneck preventing truly large language models

Solution:

Just use one side or the other!

The Great Schism

The bridge to Large Language Models required getting rid of one side of the transformer

  • Eliminate the need for cross-attention

  • If there are infinite GPUs, we can train models with an arbitrarily large number of parameters in finite time!

Which side to eliminate depends on the task.

The battle rages on today (sort of)…

The Great Schism

BERT

Representational learning via Bidirectional Encoder Representations from Transformers (BERT; Google, 2018)

Input: Input Tokens

Output: Hidden States

  • The model can see all timesteps

  • No inherent output tokens

  • No inherent auto-regressivity

  • No cross-attention steps!

Can be adapted to generate tokens and provide chatbot-style Q&A!

  • More next class.

GPT

Generative Modeling via Generative Pretrained Transformers (GPT; OpenAI, 2018)

Input: Input and (sort of) Output Tokens

Output: Sequential output tokens

This one is more in line with today’s lecture, so let’s look at this more closely!

GPT

GPT works by getting rid of cross-attention and repeatedly applying masked self-attention blocks to the input

  • Take in the entire question/answer and create a model that sequentially reveals the next token correctly with high probability

Unlike representational learning that sees all timesteps, force the model to learn the next token sequentially by feeding it back into the model!

  • Training loop forces auto-regressive prediction without ever needing to bake in recurrence!

GPT

Each sentence is sequentially masked, revealing one new token each time

Input 1:

Translate from Southern to English. Bless your little heart. You are sorely [mask] [mask] [mask]….

Output 1:

Mistaken

Input 2:

Translate from Southern to English. Bless your little heart. You are sorely mistaken [mask] [mask]….

Output 2:

<eos>

GPT

The idea: with a large enough training corpus, pick up on high-dimensional patterns in language

  • Show it enough things (say, scraped legally/illegally from the entirety of the world wide web), and it can use a trained series of masked self-attention blocks to say what should come next

  • What is the context? Then, see what comes next with high probability!

Given everything we discussed today, a pretty simple model

GPT

Scalability:

  • Since all input tokens are seen at training time, we can parallelize the masked self-attention part

  • More compute-intensive than self-attention since it needs to know what attention weights to block out at each step

At test time, no parallelizability

  • Have to go from input to input + 1 to input + 2

  • Not an issue if you have billions of investment dollars to host GPU servers

A modern miracle that ChatGPT only takes a few seconds to spout coherent answers!
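The test-time loop can be sketched as follows; `next_token` is a hypothetical stand-in for a full trained stack of masked self-attention blocks (here it just replays a canned answer one token at a time):

```python
# Sketch of the test-time bottleneck: generation is sequential because each
# predicted token must be appended before the next one can be predicted.
def next_token(tokens, start):
    answer = ["you", "are", "sorely", "mistaken", "<eos>"]
    return answer[len(tokens) - start]

tokens = ["bless", "your", "little", "heart", "<eos>", "<sos>"]
start = len(tokens)                # decoding begins after the prompt
while tokens[-1] != "<eos>":       # one full model call per generated token
    tokens.append(next_token(tokens, start))
print(tokens[start:])  # ['you', 'are', 'sorely', 'mistaken', '<eos>']
```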

GPT

A (paraphrased) quote from Sam Altman:

It is true that language models are just programmed to predict the next token. But that isn’t as simple as you might think. In fact, all animals, including us, are just programmed to survive and reproduce, and yet amazingly complex and beautiful stuff comes from it.

  • Is GPT close to sentient? Don’t we all just learn language by mimicking patterns we learn as a child?

Next Class

  • Quick overview of BERT

  • Vision Transformers/Pixel CNNs for image generation via autoregressive models!